{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp interpret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.data.all import *\n",
    "from fastai.optimizer import *\n",
    "from fastai.learner import *\n",
    "from fastai.tabular.core import *\n",
    "import sklearn.metrics as skm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastai.test_utils import *\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpretation of Predictions\n",
    "\n",
    "> Classes to build objects to better interpret predictions of a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastai.vision.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), \n",
    "                  get_items=get_image_files, \n",
    "                  splitter=RandomSubsetSplitter(.1,.1, seed=42),\n",
    "                  get_y=parent_label)\n",
    "test_dls = mnist.dataloaders(untar_data(URLs.MNIST_SAMPLE), bs=8)\n",
    "test_learner = vision_learner(test_dls, resnet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def plot_top_losses(x, y, *args, **kwargs):\n",
    "    raise Exception(f\"plot_top_losses is not implemented for {type(x)},{type(y)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "_all_ = [\"plot_top_losses\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Interpretation():\n",
    "    \"Interpretation base class, can be inherited for task specific Interpretation classes\"\n",
    "    def __init__(self,\n",
    "        learn:Learner,\n",
    "        dl:DataLoader, # `DataLoader` to run inference over\n",
    "        losses:TensorBase, # Losses calculated from `dl`\n",
    "        act=None # Activation function for prediction\n",
    "    ): \n",
    "        store_attr()\n",
    "\n",
    "    def __getitem__(self, idxs):\n",
    "        \"Return inputs, preds, targs, decoded outputs, and losses at `idxs`\"\n",
    "        if isinstance(idxs, Tensor): idxs = idxs.tolist()\n",
    "        if not is_listy(idxs): idxs = [idxs]\n",
    "        items = getattr(self.dl.items, 'iloc', L(self.dl.items))[idxs]\n",
    "        tmp_dl = self.learn.dls.test_dl(items, with_labels=True, process=not isinstance(self.dl, TabDataLoader))\n",
    "        inps,preds,targs,decoded = self.learn.get_preds(dl=tmp_dl, with_input=True, with_loss=False, \n",
    "                                                        with_decoded=True, act=self.act, reorder=False)\n",
    "        return inps, preds, targs, decoded, self.losses[idxs]\n",
    "\n",
    "    @classmethod\n",
    "    def from_learner(cls,\n",
    "        learn, # Model used to create interpretation\n",
    "        ds_idx:int=1, # Index of `learn.dls` when `dl` is None\n",
    "        dl:DataLoader=None, # `Dataloader` used to make predictions\n",
    "        act=None # Override default or set prediction activation function\n",
    "    ):\n",
    "        \"Construct interpretation object from a learner\"\n",
    "        if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)\n",
    "        _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,\n",
    "                                     with_preds=False, with_targs=False, act=act)\n",
    "        return cls(learn, dl, losses, act)\n",
    "\n",
    "    def top_losses(self,\n",
    "        k:int|None=None, # Return `k` losses, defaults to all\n",
    "        largest:bool=True, # Sort losses by largest or smallest\n",
    "        items:bool=False # Whether to return input items\n",
    "    ):\n",
    "        \"`k` largest(/smallest) losses and indexes, defaulting to all losses.\"\n",
    "        losses, idx = self.losses.topk(ifnone(k, len(self.losses)), largest=largest)\n",
    "        if items: return losses, idx, getattr(self.dl.items, 'iloc', L(self.dl.items))[idx]\n",
    "        else:     return losses, idx\n",
    "\n",
    "    def plot_top_losses(self,\n",
    "        k:int|MutableSequence, # Number of losses to plot\n",
    "        largest:bool=True, # Sort losses by largest or smallest\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"Show `k` largest(/smallest) preds and losses. Implementation based on type dispatch\"\n",
    "        if is_listy(k) or isinstance(k, range):\n",
    "            losses, idx = (o[k] for o in self.top_losses(None, largest))\n",
    "        else: \n",
    "            losses, idx = self.top_losses(k, largest)\n",
    "        inps, preds, targs, decoded, _ = self[idx]\n",
    "        inps, targs, decoded = tuplify(inps), tuplify(targs), tuplify(decoded)\n",
    "        x, y, its = self.dl._pre_show_batch(inps+targs, max_n=len(idx))\n",
    "        x1, y1, outs = self.dl._pre_show_batch(inps+decoded, max_n=len(idx))\n",
    "        if its is not None:\n",
    "            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses, **kwargs)\n",
    "        #TODO: figure out if this is needed\n",
    "        #its None means that a batch knows how to show itself as a whole, so we pass x, x1\n",
    "        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "\n",
    "    def show_results(self,\n",
    "        idxs:list, # Indices of predictions and targets\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"Show predictions and targets of `idxs`\"\n",
    "        if isinstance(idxs, Tensor): idxs = idxs.tolist()\n",
    "        if not is_listy(idxs): idxs = [idxs]\n",
    "        inps, _, targs, decoded, _ = self[idxs]\n",
    "        b = tuplify(inps)+tuplify(targs)\n",
    "        self.dl.show_results(b, tuplify(decoded), max_n=len(idxs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"Interpretation\" class=\"doc_header\"><code>class</code> <code>Interpretation</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>Interpretation</code>(**`learn`**:[`Learner`](/learner.html#Learner), **`dl`**:[`DataLoader`](/data.load.html#DataLoader), **`losses`**:[`TensorBase`](/torch_core.html#TensorBase), **`act`**=*`None`*)\n",
       "\n",
       "Interpretation base class, can be inherited for task specific Interpretation classes\n",
       "\n",
       "||Type|Default|Details|\n",
       "|---|---|---|---|\n",
       "|**`learn`**|[`Learner`](/learner.html#Learner)||*No Content*|\n",
       "|**`dl`**|[`DataLoader`](/data.load.html#DataLoader)||[`DataLoader`](/data.load.html#DataLoader) to run inference over|\n",
       "|**[`losses`](/losses.html)**|[`TensorBase`](/torch_core.html#TensorBase)||Losses calculated from `dl`|\n",
       "|**`act`**|`NoneType`|`None`|Activation function for prediction|\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Interpretation, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Interpretation` is a helper base class for exploring predictions from trained models. It can be inherited for task specific interpretation classes, such as `ClassificationInterpretation`. `Interpretation` is memory efficient and should be able to process any sized dataset, provided the hardware could train the same model.\n",
    "\n",
    ":::{.callout-note}\n",
    "\n",
    "`Interpretation` is memory efficient due to generating inputs, predictions, targets, decoded outputs, and losses for each item on the fly, using batch processing where possible.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"Interpretation.from_learner\" class=\"doc_header\"><code>Interpretation.from_learner</code><a href=\"__main__.py#L22\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>Interpretation.from_learner</code>(**`learn`**, **`ds_idx`**:`int`=*`1`*, **`dl`**:[`DataLoader`](/data.load.html#DataLoader)=*`None`*, **`act`**=*`None`*)\n",
       "\n",
       "Construct interpretation object from a learner\n",
       "\n",
       "||Type|Default|Details|\n",
       "|---|---|---|---|\n",
       "|**`learn`**|||Model used to create interpretation|\n",
       "|**`ds_idx`**|`int`|`1`|Index of `learn.dls` when `dl` is None|\n",
       "|**`dl`**|[`DataLoader`](/data.load.html#DataLoader)|`None`|`Dataloader` used to make predictions|\n",
       "|**`act`**|`NoneType`|`None`|Override default or set prediction activation function|\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Interpretation.from_learner, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"Interpretation.top_losses\" class=\"doc_header\"><code>Interpretation.top_losses</code><a href=\"__main__.py#L35\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>Interpretation.top_losses</code>(**`k`**:`(<class 'int'>, None)`=*`None`*, **`largest`**:`bool`=*`True`*, **`items`**:`bool`=*`False`*)\n",
       "\n",
       "`k` largest(/smallest) losses and indexes, defaulting to all losses.\n",
       "\n",
       "||Type|Default|Details|\n",
       "|---|---|---|---|\n",
       "|**`k`**|`(int, None)`|`None`|Return `k` losses, defaults to all|\n",
       "|**`largest`**|`bool`|`True`|Sort losses by largest or smallest|\n",
       "|**`items`**|`bool`|`False`|Whether to return input items|\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Interpretation.top_losses, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the default of `k=None`, `top_losses` will return the entire dataset's losses. `top_losses` can optionally include the input items for each loss, which is usually a file path or Pandas `DataFrame`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"Interpretation.plot_top_losses\" class=\"doc_header\"><code>Interpretation.plot_top_losses</code><a href=\"__main__.py#L45\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>Interpretation.plot_top_losses</code>(**`k`**:`(<class 'int'>, <class 'list'>)`, **`largest`**:`bool`=*`True`*, **\\*\\*`kwargs`**)\n",
       "\n",
       "Show `k` largest(/smallest) preds and losses. Implementation based on type dispatch\n",
       "\n",
       "||Type|Default|Details|\n",
       "|---|---|---|---|\n",
       "|**`k`**|`(int, list)`||Number of losses to plot|\n",
       "|**`largest`**|`bool`|`True`|Sort losses by largest or smallest|\n",
       "|**`kwargs`**|||*No Content*|\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Interpretation.plot_top_losses, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To plot the first 9 top losses:\n",
    "```python\n",
    "interp = Interpretation.from_learner(learn)\n",
    "interp.plot_top_losses(9)\n",
    "```\n",
    "Then to plot the 7th through 16th top losses:\n",
    "```python\n",
    "interp.plot_top_losses(range(7,16))\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"Interpretation.show_results\" class=\"doc_header\"><code>Interpretation.show_results</code><a href=\"__main__.py#L65\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>Interpretation.show_results</code>(**`idxs`**:`list`, **\\*\\*`kwargs`**)\n",
       "\n",
       "Show predictions and targets of `idxs`\n",
       "\n",
       "||Type|Default|Details|\n",
       "|---|---|---|---|\n",
       "|**`idxs`**|`list`||Indices of predictions and targets|\n",
       "|**`kwargs`**|||*No Content*|\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Interpretation.show_results, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like `Learner.show_results`, except can pass desired index or indicies for item(s) to show results from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#| hide\n",
    "interp = Interpretation.from_learner(test_learner)\n",
    "x, y, out = [], [], []\n",
    "for batch in test_learner.dls.valid:\n",
    "    x += batch[0]\n",
    "    y += batch[1]\n",
    "    out += test_learner.model(batch[0])\n",
    "x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)\n",
    "inps, preds, targs, decoded, losses = interp[:]\n",
    "test_eq(inps, to_cpu(x))\n",
    "test_eq(targs, to_cpu(y))\n",
    "loss = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)\n",
    "test_close(losses, to_cpu(loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#| hide\n",
    "# verify stored losses equal calculated losses for idx\n",
    "top_losses, idx = interp.top_losses(9)\n",
    "\n",
    "dl = test_learner.dls[1].new(shuffle=False, drop_last=False)\n",
    "items = getattr(dl.items, 'iloc', L(dl.items))[idx]\n",
    "tmp_dl = test_learner.dls.test_dl(items, with_labels=True, process=not isinstance(dl, TabDataLoader))\n",
    "_, _, _, _, losses = test_learner.get_preds(dl=tmp_dl, with_input=True, with_loss=True, \n",
    "                                            with_decoded=True, act=None, reorder=False)\n",
    "\n",
    "test_close(top_losses, losses, 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#| hide\n",
    "#dummy test to ensure we can run on the training set\n",
    "interp = Interpretation.from_learner(test_learner, ds_idx=0)\n",
    "x, y, out = [], [], []\n",
    "for batch in test_learner.dls.train.new(drop_last=False, shuffle=False):\n",
    "    x += batch[0]\n",
    "    y += batch[1]\n",
    "    out += test_learner.model(batch[0])\n",
    "x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)\n",
    "inps, preds, targs, decoded, losses = interp[:]\n",
    "test_eq(inps, to_cpu(x))\n",
    "test_eq(targs, to_cpu(y))\n",
    "loss = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)\n",
    "test_close(losses, to_cpu(loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class ClassificationInterpretation(Interpretation):\n",
    "    \"Interpretation methods for classification models.\"\n",
    "\n",
    "    def __init__(self, \n",
    "        learn:Learner, \n",
    "        dl:DataLoader, # `DataLoader` to run inference over\n",
    "        losses:TensorBase, # Losses calculated from `dl`\n",
    "        act=None # Activation function for prediction\n",
    "    ):\n",
    "        super().__init__(learn, dl, losses, act)\n",
    "        self.vocab = self.dl.vocab\n",
    "        if is_listy(self.vocab): self.vocab = self.vocab[-1]\n",
    "\n",
    "    def confusion_matrix(self):\n",
    "        \"Confusion matrix as an `np.ndarray`.\"\n",
    "        x = torch.arange(0, len(self.vocab))\n",
    "        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, \n",
    "                                               with_targs=True, act=self.act)\n",
    "        d,t = flatten_check(decoded, targs)\n",
    "        cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)\n",
    "        return to_np(cm)\n",
    "\n",
    "    def plot_confusion_matrix(self, \n",
    "        normalize:bool=False, # Whether to normalize occurrences\n",
    "        title:str='Confusion matrix', # Title of plot\n",
    "        cmap:str=\"Blues\", # Colormap from matplotlib\n",
    "        norm_dec:int=2, # Decimal places for normalized occurrences\n",
    "        plot_txt:bool=True, # Display occurrence in matrix\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"Plot the confusion matrix, with `title` and using `cmap`.\"\n",
    "        # This function is mainly copied from the sklearn docs\n",
    "        cm = self.confusion_matrix()\n",
    "        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "        fig = plt.figure(**kwargs)\n",
    "        plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "        plt.title(title)\n",
    "        tick_marks = np.arange(len(self.vocab))\n",
    "        plt.xticks(tick_marks, self.vocab, rotation=90)\n",
    "        plt.yticks(tick_marks, self.vocab, rotation=0)\n",
    "\n",
    "        if plot_txt:\n",
    "            thresh = cm.max() / 2.\n",
    "            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'\n",
    "                plt.text(j, i, coeff, horizontalalignment=\"center\", verticalalignment=\"center\", color=\"white\"\n",
    "                         if cm[i, j] > thresh else \"black\")\n",
    "\n",
    "        ax = fig.gca()\n",
    "        ax.set_ylim(len(self.vocab)-.5,-.5)\n",
    "\n",
    "        plt.tight_layout()\n",
    "        plt.ylabel('Actual')\n",
    "        plt.xlabel('Predicted')\n",
    "        plt.grid(False)\n",
    "\n",
    "    def most_confused(self, min_val=1):\n",
    "        \"Sorted descending largest non-diagonal entries of confusion matrix (actual, predicted, # occurrences\"\n",
    "        cm = self.confusion_matrix()\n",
    "        np.fill_diagonal(cm, 0)\n",
    "        res = [(self.vocab[i],self.vocab[j],cm[i,j]) for i,j in zip(*np.where(cm>=min_val))]\n",
    "        return sorted(res, key=itemgetter(2), reverse=True)\n",
    "\n",
    "    def print_classification_report(self):\n",
    "        \"Print scikit-learn classification report\"\n",
    "        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, \n",
    "                                               with_targs=True, act=self.act)\n",
    "        d,t = flatten_check(decoded, targs)\n",
    "        names = [str(v) for v in self.vocab]\n",
    "        print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()), target_names=names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"ClassificationInterpretation.confusion_matrix\" class=\"doc_header\"><code>ClassificationInterpretation.confusion_matrix</code><a href=\"__main__.py#L10\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>ClassificationInterpretation.confusion_matrix</code>()\n",
       "\n",
       "Confusion matrix as an `np.ndarray`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ClassificationInterpretation.confusion_matrix, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/interpret.py#L122){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### ClassificationInterpretation.plot_confusion_matrix\n",
       "\n",
       ">      ClassificationInterpretation.plot_confusion_matrix (normalize:bool=False,\n",
       ">                                                          title:str='Confusion\n",
       ">                                                          matrix',\n",
       ">                                                          cmap:str='Blues',\n",
       ">                                                          norm_dec:int=2,\n",
       ">                                                          plot_txt:bool=True,\n",
       ">                                                          **kwargs)\n",
       "\n",
       "*Plot the confusion matrix, with `title` and using `cmap`.*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| normalize | bool | False | Whether to normalize occurrences |\n",
       "| title | str | Confusion matrix | Title of plot |\n",
       "| cmap | str | Blues | Colormap from matplotlib |\n",
       "| norm_dec | int | 2 | Decimal places for normalized occurrences |\n",
       "| plot_txt | bool | True | Display occurrence in matrix |\n",
       "| kwargs | VAR_KEYWORD |  |  |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/interpret.py#L122){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### ClassificationInterpretation.plot_confusion_matrix\n",
       "\n",
       ">      ClassificationInterpretation.plot_confusion_matrix (normalize:bool=False,\n",
       ">                                                          title:str='Confusion\n",
       ">                                                          matrix',\n",
       ">                                                          cmap:str='Blues',\n",
       ">                                                          norm_dec:int=2,\n",
       ">                                                          plot_txt:bool=True,\n",
       ">                                                          **kwargs)\n",
       "\n",
       "*Plot the confusion matrix, with `title` and using `cmap`.*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| normalize | bool | False | Whether to normalize occurrences |\n",
       "| title | str | Confusion matrix | Title of plot |\n",
       "| cmap | str | Blues | Colormap from matplotlib |\n",
       "| norm_dec | int | 2 | Decimal places for normalized occurrences |\n",
       "| plot_txt | bool | True | Display occurrence in matrix |\n",
       "| kwargs | VAR_KEYWORD |  |  |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(ClassificationInterpretation.plot_confusion_matrix, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"ClassificationInterpretation.most_confused\" class=\"doc_header\"><code>ClassificationInterpretation.most_confused</code><a href=\"__main__.py#L47\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>ClassificationInterpretation.most_confused</code>(**`min_val`**=*`1`*)\n",
       "\n",
       "Sorted descending largest non-diagonal entries of confusion matrix (actual, predicted, # occurrences"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ClassificationInterpretation.most_confused, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#| hide\n",
    "# simple test to make sure ClassificationInterpretation works\n",
    "interp = ClassificationInterpretation.from_learner(test_learner)\n",
    "cm = interp.confusion_matrix()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class SegmentationInterpretation(Interpretation):\n",
    "    \"Interpretation methods for segmentation models.\"\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
